import ptan
import numpy as np
import torch
import torch.nn as nn

# Hidden-layer width constant.
# NOTE(review): not referenced by the model classes defined below (their fully
# connected layers use 512 units) — confirm whether other modules import it
# before removing or changing it.
HID_SIZE = 128


class ModelActor(nn.Module):
    '''
    Convolutional policy (actor) network for continuous actions.

    Maps a batch of image observations to the mean of every action component.
    A learnable, state-independent log standard deviation (``logstd``) is kept
    next to it so that an agent can sample exploratory actions around the mean.
    '''
    def __init__(self, obs_action, act_size):
        '''
        :param obs_action: shape of one observation, channels first (C, H, W);
            obs_action[0] is the number of input channels
        :param act_size: dimensionality of the action
        '''
        super(ModelActor, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(obs_action[0], 64, kernel_size=8, stride=4),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        conv_out_size = self._get_conv_out(obs_action)
        self.mu = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, act_size)
        )
        # Log standard deviation of the exploration noise, one entry per action
        # component. AgentA2C reads `net.logstd`, so the actor has to provide it.
        self.logstd = nn.Parameter(torch.zeros(act_size))

    def _get_conv_out(self, shape):
        '''
        Return the number of features self.conv produces for one observation.

        The probe runs in eval mode and without gradients. A train-mode pass
        would update the BatchNorm running statistics with the all-zeros dummy
        input. It would also fail on a 1x1 feature map with batch size 1.
        The previous train/eval mode is restored afterwards.
        '''
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                o = self.conv(torch.zeros(1, *shape))
        finally:
            self.conv.train(was_training)
        return int(np.prod(o.size()))

    def forward(self, x):
        '''
        :param x: batch of observations (N, C, H, W); presumably raw pixel
            values in [0, 255], given the /255 scaling below
        :return: action means, shape (N, act_size)
        '''
        # .float() first so integer (e.g. uint8) frames are scaled with true
        # float division, consistent with ModelCritic.forward
        x = x.float() / 255.0
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.mu(conv_out)


class ModelCritic(nn.Module):
    '''
    Convolutional critic (state-value) network.

    Intended as the value network for TRPO (trust region policy optimization),
    ACKTR and PPO training.
    '''
    def __init__(self, obs_size):
        '''
        :param obs_size: shape of one observation, channels first (C, H, W);
            obs_size[0] is the number of input channels
        '''
        super(ModelCritic, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(obs_size[0], 64, kernel_size=8, stride=4),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        conv_out_size = self._get_conv_out(obs_size)
        self.value = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def _get_conv_out(self, shape):
        '''
        Return the number of features self.conv produces for one observation.

        The probe runs in eval mode and without gradients. A train-mode pass
        would update the BatchNorm running statistics with the all-zeros dummy
        input. It would also fail on a 1x1 feature map with batch size 1.
        The previous train/eval mode is restored afterwards.
        '''
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                o = self.conv(torch.zeros(1, *shape))
        finally:
            self.conv.train(was_training)
        return int(np.prod(o.size()))

    def forward(self, x):
        '''
        :param x: batch of observations (N, C, H, W); presumably raw pixel
            values in [0, 255], given the /255 scaling below
        :return: state-value estimates, shape (N, 1)
        '''
        # .float() so integer frames are scaled with true float division
        x = x.float() / 255.0
        conv_out = self.conv(x).view(x.size()[0], -1)

        return self.value(conv_out)


class AgentA2C(ptan.agent.BaseAgent):
    '''
    Agent that turns a batch of observations into continuous actions.

    The network output is used as the action mean. Gaussian exploration noise
    with standard deviation exp(net.logstd) is added, and the result is clipped
    to [-1, 1].
    '''
    def __init__(self, net, device="cpu"):
        '''
        :param net: policy network; forward() must return action means and the
            module must expose a ``logstd`` tensor (log std per action component)
        :param device: device the observations are moved to before the forward pass
        '''
        self.net = net
        self.device = device

    def __call__(self, states, agent_states):
        '''
        :param states: batch of observed environment states
        :param agent_states: the agent's own per-environment state; not used
            here and returned unchanged
        :return: tuple (actions, agent_states); actions is a numpy array with
            the same shape as the network output, clipped to [-1, 1]
        '''
        # Convert the observations into a float32 tensor on the target device
        states_v = ptan.agent.float32_preprocessor(states).to(self.device)

        # Acting needs no gradients, so don't build an autograd graph
        with torch.no_grad():
            mu_v = self.net(states_v)
        mu = mu_v.detach().cpu().numpy()
        logstd = self.net.logstd.detach().cpu().numpy()
        # Add Gaussian exploration noise around the predicted mean. The noise is
        # sampled with mu's shape so every element of the batch (every
        # environment) gets its own independent sample; sampling with
        # logstd.shape would broadcast one shared noise vector to all rows.
        actions = mu + np.exp(logstd) * np.random.normal(size=mu.shape)
        # Clip to [-1, 1]; presumably the environment's action bounds — confirm
        actions = np.clip(actions, -1, 1)
        return actions, agent_states
